{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "ecf642c4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Tue Dec  5 08:08:22 2023       \r\n",
      "+-----------------------------------------------------------------------------+\r\n",
      "| NVIDIA-SMI 525.85.12    Driver Version: 525.85.12    CUDA Version: 12.0     |\r\n",
      "|-------------------------------+----------------------+----------------------+\r\n",
      "| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |\r\n",
      "| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |\r\n",
      "|                               |                      |               MIG M. |\r\n",
      "|===============================+======================+======================|\r\n",
      "|   0  NVIDIA A100 80G...  On   | 00000001:00:00.0 Off |                    0 |\r\n",
      "| N/A   46C    P0    66W / 300W |  22394MiB / 81920MiB |      0%      Default |\r\n",
      "|                               |                      |             Disabled |\r\n",
      "+-------------------------------+----------------------+----------------------+\r\n",
      "                                                                               \r\n",
      "+-----------------------------------------------------------------------------+\r\n",
      "| Processes:                                                                  |\r\n",
      "|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |\r\n",
      "|        ID   ID                                                   Usage      |\r\n",
      "|=============================================================================|\r\n",
      "+-----------------------------------------------------------------------------+\r\n"
     ]
    }
   ],
   "source": [
    "!nvidia-smi"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "d4c3a83e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ[\"WANDB_DISABLED\"] = \"true\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "9b310aa0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Copyright (c) 2023, Albert Gu, Tri Dao.\n",
    "\n",
    "import math\n",
    "from functools import partial\n",
    "\n",
    "from collections import namedtuple\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.nn import CrossEntropyLoss\n",
    "from transformers import PretrainedConfig\n",
    "\n",
    "from mamba_ssm.modules.mamba_simple import Mamba, Block\n",
    "from mamba_ssm.utils.generation import GenerationMixin\n",
    "from mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf\n",
    "\n",
    "try:\n",
    "    from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn\n",
    "except ImportError:\n",
    "    RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None\n",
    "\n",
    "\n",
    "def create_block(\n",
    "    d_model,\n",
    "    ssm_cfg=None,\n",
    "    norm_epsilon=1e-5,\n",
    "    rms_norm=False,\n",
    "    residual_in_fp32=False,\n",
    "    fused_add_norm=False,\n",
    "    layer_idx=None,\n",
    "    device=None,\n",
    "    dtype=None,\n",
    "):\n",
    "    if ssm_cfg is None:\n",
    "        ssm_cfg = {}\n",
    "    factory_kwargs = {\"device\": device, \"dtype\": dtype}\n",
    "    mixer_cls = partial(Mamba, layer_idx=layer_idx, **ssm_cfg, **factory_kwargs)\n",
    "    norm_cls = partial(\n",
    "        nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs\n",
    "    )\n",
    "    block = Block(\n",
    "        d_model,\n",
    "        mixer_cls,\n",
    "        norm_cls=norm_cls,\n",
    "        fused_add_norm=fused_add_norm,\n",
    "        residual_in_fp32=residual_in_fp32,\n",
    "    )\n",
    "    block.layer_idx = layer_idx\n",
    "    return block\n",
    "\n",
    "\n",
    "# https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454\n",
    "def _init_weights(\n",
    "    module,\n",
    "    n_layer,\n",
    "    initializer_range=0.02,  # Now only used for embedding layer.\n",
    "    rescale_prenorm_residual=True,\n",
    "    n_residuals_per_layer=1,  # Change to 2 if we have MLP\n",
    "):\n",
    "    if isinstance(module, nn.Linear):\n",
    "        if module.bias is not None:\n",
    "            if not getattr(module.bias, \"_no_reinit\", False):\n",
    "                nn.init.zeros_(module.bias)\n",
    "    elif isinstance(module, nn.Embedding):\n",
    "        nn.init.normal_(module.weight, std=initializer_range)\n",
    "\n",
    "    if rescale_prenorm_residual:\n",
    "        # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:\n",
    "        #   > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale\n",
    "        #   > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.\n",
    "        #   >   -- GPT-2 :: https://openai.com/blog/better-language-models/\n",
    "        #\n",
    "        # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py\n",
    "        for name, p in module.named_parameters():\n",
    "            if name in [\"out_proj.weight\", \"fc2.weight\"]:\n",
    "                # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block\n",
    "                # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)\n",
    "                # We need to reinit p since this code could be called multiple times\n",
    "                # Having just p *= scale would repeatedly scale it down\n",
    "                nn.init.kaiming_uniform_(p, a=math.sqrt(5))\n",
    "                with torch.no_grad():\n",
    "                    p /= math.sqrt(n_residuals_per_layer * n_layer)\n",
    "\n",
    "\n",
    "class MixerModel(nn.Module):\n",
    "    def __init__(\n",
    "        self,\n",
    "        d_model: int,\n",
    "        n_layer: int,\n",
    "        vocab_size: int,\n",
    "        ssm_cfg=None,\n",
    "        norm_epsilon: float = 1e-5,\n",
    "        rms_norm: bool = False,\n",
    "        initializer_cfg=None,\n",
    "        fused_add_norm=False,\n",
    "        residual_in_fp32=False,\n",
    "        device=None,\n",
    "        dtype=None,\n",
    "    ) -> None:\n",
    "        factory_kwargs = {\"device\": device, \"dtype\": dtype}\n",
    "        super().__init__()\n",
    "        self.residual_in_fp32 = residual_in_fp32\n",
    "\n",
    "        self.embedding = nn.Embedding(vocab_size, d_model, **factory_kwargs)\n",
    "\n",
    "        # We change the order of residual and layer norm:\n",
    "        # Instead of LN -> Attn / MLP -> Add, we do:\n",
    "        # Add -> LN -> Attn / MLP / Mixer, returning both the residual branch (output of Add) and\n",
    "        # the main branch (output of MLP / Mixer). The model definition is unchanged.\n",
    "        # This is for performance reason: we can fuse add + layer_norm.\n",
    "        self.fused_add_norm = fused_add_norm\n",
    "        if self.fused_add_norm:\n",
    "            if layer_norm_fn is None or rms_norm_fn is None:\n",
    "                raise ImportError(\"Failed to import Triton LayerNorm / RMSNorm kernels\")\n",
    "\n",
    "        self.layers = nn.ModuleList(\n",
    "            [\n",
    "                create_block(\n",
    "                    d_model,\n",
    "                    ssm_cfg=ssm_cfg,\n",
    "                    norm_epsilon=norm_epsilon,\n",
    "                    rms_norm=rms_norm,\n",
    "                    residual_in_fp32=residual_in_fp32,\n",
    "                    fused_add_norm=fused_add_norm,\n",
    "                    layer_idx=i,\n",
    "                    **factory_kwargs,\n",
    "                )\n",
    "                for i in range(n_layer)\n",
    "            ]\n",
    "        )\n",
    "\n",
    "        self.norm_f = (nn.LayerNorm if not rms_norm else RMSNorm)(\n",
    "            d_model, eps=norm_epsilon, **factory_kwargs\n",
    "        )\n",
    "\n",
    "        self.apply(\n",
    "            partial(\n",
    "                _init_weights,\n",
    "                n_layer=n_layer,\n",
    "                **(initializer_cfg if initializer_cfg is not None else {}),\n",
    "            )\n",
    "        )\n",
    "\n",
    "    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):\n",
    "        return {\n",
    "            i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)\n",
    "            for i, layer in enumerate(self.layers)\n",
    "        }\n",
    "\n",
    "    def forward(self, input_ids, inference_params=None):\n",
    "        hidden_states = self.embedding(input_ids)\n",
    "        residual = None\n",
    "        for layer in self.layers:\n",
    "            hidden_states, residual = layer(\n",
    "                hidden_states, residual, inference_params=inference_params\n",
    "            )\n",
    "        if not self.fused_add_norm:\n",
    "            residual = (hidden_states + residual) if residual is not None else hidden_states\n",
    "            hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))\n",
    "        else:\n",
    "            # Set prenorm=False here since we don't need the residual\n",
    "            fused_add_norm_fn = rms_norm_fn if isinstance(self.norm_f, RMSNorm) else layer_norm_fn\n",
    "            hidden_states = fused_add_norm_fn(\n",
    "                hidden_states,\n",
    "                self.norm_f.weight,\n",
    "                self.norm_f.bias,\n",
    "                eps=self.norm_f.eps,\n",
    "                residual=residual,\n",
    "                prenorm=False,\n",
    "                residual_in_fp32=self.residual_in_fp32,\n",
    "            )\n",
    "        return hidden_states\n",
    "\n",
    "\n",
    "class MambaLMHeadModel(nn.Module, GenerationMixin):\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        d_model: int,\n",
    "        n_layer: int,\n",
    "        vocab_size: int,\n",
    "        initializer_cfg=None,\n",
    "        pad_vocab_size_multiple: int = 1,\n",
    "        device=None,\n",
    "        dtype=None,\n",
    "        **backbone_kwargs,\n",
    "    ) -> None:\n",
    "        factory_kwargs = {\"device\": device, \"dtype\": dtype}\n",
    "        super().__init__()\n",
    "        if vocab_size % pad_vocab_size_multiple != 0:\n",
    "            vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple)\n",
    "        self.backbone = MixerModel(\n",
    "            d_model=d_model,\n",
    "            n_layer=n_layer,\n",
    "            vocab_size=vocab_size,\n",
    "            initializer_cfg=initializer_cfg,\n",
    "            **backbone_kwargs,\n",
    "            **factory_kwargs,\n",
    "        )\n",
    "        self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs)\n",
    "\n",
    "        # Initialize weights and apply final processing\n",
    "        self.apply(\n",
    "            partial(\n",
    "                _init_weights,\n",
    "                n_layer=n_layer,\n",
    "                **(initializer_cfg if initializer_cfg is not None else {}),\n",
    "            )\n",
    "        )\n",
    "        self.tie_weights()\n",
    "        self.config = PretrainedConfig(\n",
    "            d_model = d_model,\n",
    "            n_layer = n_layer,\n",
    "            vocab_size = vocab_size,\n",
    "            hidden_size = d_model,\n",
    "        )\n",
    "\n",
    "    def tie_weights(self):\n",
    "        self.lm_head.weight = self.backbone.embedding.weight\n",
    "\n",
    "    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):\n",
    "        return self.backbone.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)\n",
    "\n",
    "    def forward(self, input_ids, position_ids=None, inference_params=None, num_last_tokens=0, labels = None):\n",
    "        \"\"\"\n",
    "        \"position_ids\" is just to be compatible with Transformer generation. We don't use it.\n",
    "        num_last_tokens: if > 0, only return the logits for the last n tokens\n",
    "        \"\"\"\n",
    "        hidden_states = self.backbone(input_ids, inference_params=inference_params)\n",
    "        if num_last_tokens > 0:\n",
    "            hidden_states = hidden_states[:, -num_last_tokens:]\n",
    "        lm_logits = self.lm_head(hidden_states)\n",
    "        \n",
    "        loss = None\n",
    "        if labels is not None:\n",
    "            logits = lm_logits\n",
    "            # Shift so that tokens < n predict n\n",
    "            shift_logits = logits[..., :-1, :].contiguous()\n",
    "            shift_labels = labels[..., 1:].contiguous()\n",
    "            # Flatten the tokens\n",
    "            loss_fct = CrossEntropyLoss()\n",
    "            shift_logits = shift_logits.view(-1, self.config.vocab_size)\n",
    "            shift_labels = shift_labels.view(-1)\n",
    "            # Enable model parallelism\n",
    "            shift_labels = shift_labels.to(shift_logits.device)\n",
    "            loss = loss_fct(shift_logits, shift_labels)\n",
    "            CausalLMOutput = namedtuple(\"CausalLMOutput\", [\"logits\", \"loss\"])\n",
    "            print(loss)\n",
    "            return CausalLMOutput(logits=lm_logits, loss = loss)\n",
    "            \n",
    "        else:\n",
    "            CausalLMOutput = namedtuple(\"CausalLMOutput\", [\"logits\"])\n",
    "            return CausalLMOutput(logits=lm_logits)\n",
    "\n",
    "    @classmethod\n",
    "    def from_pretrained(cls, pretrained_model_name, device=None, dtype=None, **kwargs):\n",
    "        config = load_config_hf(pretrained_model_name)\n",
    "        model = cls(**config, device=device, dtype=dtype, **kwargs)\n",
    "        model.load_state_dict(load_state_dict_hf(pretrained_model_name, device=device, dtype=dtype))\n",
    "        return model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "4fe72f34",
   "metadata": {},
   "outputs": [],
   "source": [
    "# !wget https://huggingface.co/state-spaces/mamba-130m/raw/main/config.json -O config-130m.json"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "eac72d2b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "with open('config-130m.json') as fopen:\n",
    "    config = json.load(fopen)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "b85b48be",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = MambaLMHeadModel(**{**config, 'vocab_size': 32000}, dtype = torch.bfloat16)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "1b9ad22c",
   "metadata": {},
   "outputs": [],
   "source": [
    "_ = model.cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "d27ebaaa",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "32000"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.config.vocab_size"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "e8beea8d",
   "metadata": {},
   "outputs": [],
   "source": [
    "from streaming import LocalDataset\n",
    "import numpy as np\n",
    "from streaming.base.format.mds.encodings import Encoding, _encodings\n",
    "\n",
    "class UInt16(Encoding):\n",
    "    def encode(self, obj) -> bytes:\n",
    "        return obj.tobytes()\n",
    "\n",
    "    def decode(self, data: bytes):\n",
    "        return np.frombuffer(data, np.uint16)\n",
    "\n",
    "_encodings['uint16'] = UInt16"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "74c0f1fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# !git lfs clone https://huggingface.co/datasets/malaysia-ai/mosaic-instructions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "88f14bc3",
   "metadata": {},
   "outputs": [],
   "source": [
    "class DatasetFixed(torch.utils.data.Dataset):\n",
    "    def __init__(self, local):\n",
    "        self.dataset = LocalDataset(local=local)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        data = self.dataset[idx]\n",
    "        data['labels'] = data['input_ids'].copy()\n",
    "\n",
    "        data.pop('token_type_ids', None)\n",
    "        for k in data.keys():\n",
    "            data[k] = data[k].astype(np.int64)\n",
    "        return data\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.dataset)\n",
    "\n",
    "train_dataset = DatasetFixed(local='mosaic-instructions')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "765e5457",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import default_data_collator\n",
    "\n",
    "b = [train_dataset[i] for i in range(2)]\n",
    "b = default_data_collator(b)\n",
    "for k in b:\n",
    "    b[k] = b[k].cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "569442d7",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.optim import Adam"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "95ccfa78",
   "metadata": {},
   "outputs": [],
   "source": [
    "optimizer = Adam(model.parameters(),1e-4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "889f246f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(10.4375, device='cuda:0', dtype=torch.bfloat16,\n",
      "       grad_fn=<NllLossBackward0>)\n",
      "0 tensor(10.4375, device='cuda:0', dtype=torch.bfloat16,\n",
      "       grad_fn=<NllLossBackward0>)\n",
      "tensor(nan, device='cuda:0', dtype=torch.bfloat16, grad_fn=<NllLossBackward0>)\n",
      "2 tensor(nan, device='cuda:0', dtype=torch.bfloat16, grad_fn=<NllLossBackward0>)\n",
      "tensor(nan, device='cuda:0', dtype=torch.bfloat16, grad_fn=<NllLossBackward0>)\n",
      "4 tensor(nan, device='cuda:0', dtype=torch.bfloat16, grad_fn=<NllLossBackward0>)\n",
      "tensor(nan, device='cuda:0', dtype=torch.bfloat16, grad_fn=<NllLossBackward0>)\n",
      "6 tensor(nan, device='cuda:0', dtype=torch.bfloat16, grad_fn=<NllLossBackward0>)\n",
      "tensor(nan, device='cuda:0', dtype=torch.bfloat16, grad_fn=<NllLossBackward0>)\n",
      "8 tensor(nan, device='cuda:0', dtype=torch.bfloat16, grad_fn=<NllLossBackward0>)\n",
      "tensor(nan, device='cuda:0', dtype=torch.bfloat16, grad_fn=<NllLossBackward0>)\n",
      "10 tensor(nan, device='cuda:0', dtype=torch.bfloat16, grad_fn=<NllLossBackward0>)\n",
      "tensor(nan, device='cuda:0', dtype=torch.bfloat16, grad_fn=<NllLossBackward0>)\n",
      "12 tensor(nan, device='cuda:0', dtype=torch.bfloat16, grad_fn=<NllLossBackward0>)\n",
      "tensor(nan, device='cuda:0', dtype=torch.bfloat16, grad_fn=<NllLossBackward0>)\n",
      "14 tensor(nan, device='cuda:0', dtype=torch.bfloat16, grad_fn=<NllLossBackward0>)\n",
      "tensor(nan, device='cuda:0', dtype=torch.bfloat16, grad_fn=<NllLossBackward0>)\n",
      "16 tensor(nan, device='cuda:0', dtype=torch.bfloat16, grad_fn=<NllLossBackward0>)\n",
      "tensor(nan, device='cuda:0', dtype=torch.bfloat16, grad_fn=<NllLossBackward0>)\n",
      "18 tensor(nan, device='cuda:0', dtype=torch.bfloat16, grad_fn=<NllLossBackward0>)\n"
     ]
    }
   ],
   "source": [
    "batch_size = 2\n",
    "for i in range(0, 20, batch_size):\n",
    "    b = [train_dataset[i + k] for k in range(batch_size)]\n",
    "    b = default_data_collator(b)\n",
    "    for k in b:\n",
    "        b[k] = b[k].cuda()\n",
    "    optimizer.zero_grad()\n",
    "    loss = model(**b).loss\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "    print(i, loss)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
